#!pip install -r libraries_to_install.txt
import os
import random
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from pathlib import Path
from statistics import mean
import torchvision.models as models
from easyfsl.methods.utils import evaluate
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from easyfsl.methods import PrototypicalNetworks, FewShotClassifier
from easyfsl.modules import resnet12
from easyfsl.datasets import PLANT
from easyfsl.samplers import TaskSampler
from torch.utils.data import DataLoader
import matplotlib.image as mpimg
from PIL import Image
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
if torch.cuda.is_available()==True:
print('GPUs are available! ')
else:
print('Please configure GPSs are not available')
GPUs are available!
img = mpimg.imread('./data/PLANT/100/DSC05982.jpg')
print(img.shape)
plt.imshow(img)
(256, 256, 3)
<matplotlib.image.AxesImage at 0x1e3f3be6670>
img = mpimg.imread('./data/PLANT/256/DSC09062.jpg')
print(img.shape)
plt.imshow(img)
(256, 256, 3)
<matplotlib.image.AxesImage at 0x1e3f3c65340>
img = mpimg.imread('./data/PLANT/316/DSC03840.jpg')
print(img.shape)
plt.imshow(img)
(256, 256, 3)
<matplotlib.image.AxesImage at 0x1e3f3cdc520>
img = mpimg.imread('./data/PLANT/330/DSC06136.jpg')
print(img.shape)
plt.imshow(img)
(256, 256, 3)
<matplotlib.image.AxesImage at 0x1e3f3d5b4c0>
img = mpimg.imread('./data/PLANT/348/DSC01037.jpg')
print(img.shape)
plt.imshow(img)
(256, 256, 3)
<matplotlib.image.AxesImage at 0x1e3f5dce6d0>
img = mpimg.imread('./data/PLANT/370/DSC01163.jpg')
print(img.shape)
plt.imshow(img)
(256, 256, 3)
<matplotlib.image.AxesImage at 0x1e3f5e41bb0>
images_data = []
for i in os.listdir('./data/PLANT/110/')[0:10]:
split = i.split('_')
images_data.append(Image.open('./data/PLANT/110/' + i))
plt.figure(figsize=(10,10))
for i in range(10):
plt.subplot(5,2,i+1)
plt.imshow(images_data[i])
plt.show()
images_data = []
for i in os.listdir('./data/PLANT/150/')[0:10]:
split = i.split('_')
images_data.append(Image.open('./data/PLANT/150/' + i))
plt.figure(figsize=(10,10))
for i in range(10):
plt.subplot(5,2,i+1)
plt.imshow(images_data[i])
plt.show()
images_data = []
for i in os.listdir('./data/PLANT/200/')[0:10]:
split = i.split('_')
images_data.append(Image.open('./data/PLANT/200/' + i))
plt.figure(figsize=(10,10))
for i in range(10):
plt.subplot(5,2,i+1)
plt.imshow(images_data[i])
plt.show()
images_data = []
for i in os.listdir('./data/PLANT/370/')[0:10]:
split = i.split('_')
images_data.append(Image.open('./data/PLANT/370/' + i))
plt.figure(figsize=(10,10))
for i in range(10):
plt.subplot(5,2,i+1)
plt.imshow(images_data[i])
plt.show()
Training Data
We use training data when we train the models. We feed train data to machine learning and deep learning models so that model can learn from the data.
Validation Data
We use validation data while training the model. We use this data to evalaute the performance that how the model perform on training time.
Testing Data
We use testing data after training the model. We use this data to evalaute the performance that how the model perform after training. So in this way first we get predictions from the trained model without giving the labels and then we compare the true labels with predictions and get the performance of th model..
n_way = 271 # number of classes
n_shot = 1 #number of samples
n_query = 1 # Number of images per class in the query set
DEVICE = "cuda"
n_workers = 5
n_tasks_per_epoch = 2000
n_validation_tasks = 2000
train_set = PLANT(split="train", training=True)
val_set = PLANT(split="test", training=False)
train_sampler = TaskSampler(
train_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_tasks_per_epoch
)
val_sampler = TaskSampler(
val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)
train_loader = DataLoader(
train_set,
batch_sampler=train_sampler,
num_workers=n_workers,
pin_memory=True,
collate_fn=train_sampler.episodic_collate_fn,
)
val_loader = DataLoader(
val_set,
batch_sampler=val_sampler,
num_workers=n_workers,
pin_memory=True,
collate_fn=val_sampler.episodic_collate_fn,
)
tranfer_learning_model = models.googlenet(pretrained=True)
few_shot_classifier = PrototypicalNetworks(tranfer_learning_model).to(DEVICE)
LOSS_FUNCTION = nn.CrossEntropyLoss()
n_epochs = 50
scheduler_milestones = [20, 30]
scheduler_gamma = 0.1
learning_rate = 1e-2
train_optimizer = SGD(
few_shot_classifier.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4
)
train_scheduler = MultiStepLR(
train_optimizer,
milestones=scheduler_milestones,
gamma=scheduler_gamma,
)
def training_epoch(
model: FewShotClassifier, data_loader: DataLoader, optimizer: Optimizer
):
all_loss = []
model.train()
with tqdm(
enumerate(data_loader), total=len(data_loader), desc="Training"
) as tqdm_train:
for episode_index, (
support_images,
support_labels,
query_images,
query_labels,
_,
) in tqdm_train:
optimizer.zero_grad()
model.process_support_set(
support_images.to(DEVICE), support_labels.to(DEVICE)
)
classification_scores = model(query_images.to(DEVICE))
loss = LOSS_FUNCTION(classification_scores, query_labels.to(DEVICE))
loss.backward()
optimizer.step()
all_loss.append(loss.item())
tqdm_train.set_postfix(loss=mean(all_loss))
return mean(all_loss)
best_state = few_shot_classifier.state_dict()
best_validation_accuracy = 0.0
for epoch in range(n_epochs):
print(f"Epoch {epoch}")
average_loss = training_epoch(few_shot_classifier, train_loader, train_optimizer)
validation_accuracy = evaluate(
few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
)
if validation_accuracy > best_validation_accuracy:
best_validation_accuracy = validation_accuracy
best_state = few_shot_classifier.state_dict()
train_scheduler.step()
Epoch 0
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.01it/s, loss=9.31] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:02<00:00, 6.11it/s, accuracy=0.241]
Epoch 1
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:04<00:00, 3.46it/s, loss=9.24] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.61it/s, accuracy=0.242]
Epoch 2
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:03<00:00, 3.70it/s, loss=9.77] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.50it/s, accuracy=0.244]
Epoch 3
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.59it/s, loss=9.43] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.73it/s, accuracy=0.246]
Epoch 4
Training: 100%|██████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.41it/s, loss=8.94] Validation: 100%|███████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.52it/s, accuracy=0.247]
Epoch 5
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.38it/s, loss=8.4] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.36it/s, accuracy=0.254]
Epoch 6
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.23it/s, loss=8.79] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.15it/s, accuracy=0.259]
Epoch 7
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.28it/s, loss=8.1] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.16it/s, accuracy=0.262]
Epoch 8
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.36it/s, loss=7.29] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.59it/s, accuracy=0.274]
Epoch 9
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.37it/s, loss=7.32] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.17it/s, accuracy=0.279]
Epoch 10
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.24it/s, loss=7.92] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:07<00:00, 1.95it/s, accuracy=0.282]
Epoch 11
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.55it/s, loss=7.47] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:08<00:00, 1.70it/s, accuracy=0.284]
Epoch 12
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.22it/s, loss=7.38] Validation: 100%|███████████████████████████████████████████████████████| 2000/2000 [00:36<00:00, 2.63s/it, accuracy=0.301]
Epoch 13
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.60it/s, loss=6.56] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.07it/s, accuracy=0.324]
Epoch 14
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.48it/s, loss=6.37] Validation: 100%|███████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.07it/s, accuracy=0.338]
Epoch 15
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.55it/s, loss=6.3] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.07it/s, accuracy=0.346]
Epoch 16
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.61it/s, loss=5.93] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.00it/s, accuracy=0.352]
Epoch 17
Training: 100%|██████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.48it/s, loss=5.08] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.07it/s, accuracy=0.354]
Epoch 18
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.19it/s, loss=5.58] Validation: 100%|███████████████████████████████████████████████████████| 2000/2000 [00:09<00:00, 1.52it/s, accuracy=0.355]
Epoch 19
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.37it/s, loss=5.2] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.12it/s, accuracy=0.369]
Epoch 20
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.37it/s, loss=5.66] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:08<00:00, 1.70it/s, accuracy=0.378]
Epoch 21
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.76it/s, loss=4.56] Validation: 100%|████████████████████████████████████████████████████████| 2000/2000 [00:08<00:00, 1.62it/s, accuracy=0.38]
Epoch 22
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.50it/s, loss=4.66] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:09<00:00, 1.44it/s, accuracy=0.384]
Epoch 23
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.57it/s, loss=4.37] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:07<00:00, 1.85it/s, accuracy=0.393]
Epoch 24
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.58it/s, loss=4.25] Validation: 100%|███████████████████████████████████████████████████████| 2000/2000 [00:07<00:00, 1.95it/s, accuracy=0.396]
Epoch 25
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.61it/s, loss=3.82] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:07<00:00, 1.81it/s, accuracy=0.413]
Epoch 26
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.31it/s, loss=3.89] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:08<00:00, 1.64it/s, accuracy=0.414]
Epoch 27
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.25it/s, loss=3.65] Validation: 100%|███████████████████████████████████████████████████████| 2000/2000 [00:07<00:00, 1.86it/s, accuracy=0.417]
Epoch 28
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.31it/s, loss=3.36] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:09<00:00, 1.50it/s, accuracy=0.423]
Epoch 29
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.34it/s, loss=3.45] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:07<00:00, 1.80it/s, accuracy=0.424]
Epoch 30
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.54it/s, loss=2.84] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:09<00:00, 1.46it/s, accuracy=0.425]
Epoch 31
Training: 100%|██████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.61it/s, loss=2.62] Validation: 100%|████████████████████████████████████████████████████████| 2000/2000 [00:08<00:00, 1.71it/s, accuracy=0.43]
Epoch 32
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.56it/s, loss=2.76] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:07<00:00, 1.78it/s, accuracy=0.432]
Epoch 33
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.71it/s, loss=2.95] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:11<00:00, 1.21it/s, accuracy=0.441]
Epoch 34
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.54it/s, loss=2.87] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:09<00:00, 1.47it/s, accuracy=0.447]
Epoch 35
Training: 100%|██████████████████████████████████████████████████████████████| 2000/2000 [00:07<00:00, 1.99it/s, loss=1.27] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:11<00:00, 1.27it/s, accuracy=0.453]
Epoch 36
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.65it/s, loss=1.59] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:08<00:00, 1.62it/s, accuracy=0.466]
Epoch 37
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.63it/s, loss=1.72] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:09<00:00, 1.44it/s, accuracy=0.467]
Epoch 38
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.50it/s, loss=1.41] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:09<00:00, 1.45it/s, accuracy=0.478]
Epoch 39
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.45it/s, loss=1.98] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:10<00:00, 1.28it/s, accuracy=0.491]
Epoch 40
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.27it/s, loss=1.83] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:08<00:00, 1.59it/s, accuracy=0.495]
Epoch 41
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.66it/s, loss=1.62] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:07<00:00, 1.82it/s, accuracy=0.512]
Epoch 42
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.67it/s, loss=1.62] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:09<00:00, 1.49it/s, accuracy=0.513]
Epoch 43
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.62it/s, loss=1.62] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:08<00:00, 1.75it/s, accuracy=0.513]
Epoch 44
Training: 100%|██████████████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.26it/s, loss=1.53] Validation: 100%|███████████████████████████████████████████████████████| 2000/2000 [00:08<00:00, 1.62it/s, accuracy=0.533]
Epoch 45
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.37it/s, loss=1.4] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:10<00:00, 1.39it/s, accuracy=0.536]
Epoch 46
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.32it/s, loss=1.36] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:11<00:00, 1.23it/s, accuracy=0.537]
Epoch 47
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.51it/s, loss=1.35] Validation: 100%|███████████████████████████████████████████████████████| 2000/2000 [00:07<00:00, 1.89it/s, accuracy=0.54]
Epoch 48
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:06<00:00, 2.27it/s, loss=1.25] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:07<00:00, 1.75it/s, accuracy=0.554]
Epoch 49
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00, 2.52it/s, loss=1.05] Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:08<00:00, 1.72it/s, accuracy=0.564]
few_shot_classifier.load_state_dict(best_state)
<All keys matched successfully>
n_test_tasks = 2000
test_set = PLANT(split="test", training=False)
test_sampler = TaskSampler(
test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)
test_loader = DataLoader(
test_set,
batch_sampler=test_sampler,
num_workers=n_workers,
pin_memory=True,
collate_fn=test_sampler.episodic_collate_fn,
)
accuracy = evaluate(few_shot_classifier, test_loader, device=DEVICE)
print(f"Average accuracy : {(100 * accuracy):.1f} %")
100%|██████████████████████████████████████████████████████████████████| 2000/2000 [00:07<00:00, 1.76it/s, accuracy=0.559]
Average accuracy : 55.9 %